from __future__ import annotations

from datetime import date
from pathlib import Path

import pytest

import polars as pl
import polars.selectors as cs
from polars.exceptions import InvalidOperationError
from polars.testing import assert_frame_equal


@pytest.fixture
def foods_ipc_path() -> Path:
    """Return the location of ``foods1.ipc`` under the ``io/files`` directory.

    The path is resolved relative to this module: two directory levels up,
    then down into ``io/files``.
    """
    two_levels_up = Path(__file__).parent.parent
    return two_levels_up.joinpath("io", "files", "foods1.ipc")


def test_div() -> None:
    """Check SQL ``/``, ``//`` and ``SIGN`` over a pair of float columns."""
    numerators = [10.0, 20.0, 30.0, 40.0, 50.0]
    denominators = [-100.5, 7.0, 2.5, None, -3.14]
    frame = pl.LazyFrame({"a": numerators, "b": denominators})

    # one inner list per output column; the null in "b" is expected to
    # surface as a null in each derived column
    expected = pl.DataFrame(
        [
            [-0.0995024875621891, 2.85714285714286, 12.0, None, -15.92356687898089],
            [-1, 2, 12, None, -16],
            [-1.0, 1.0, 1.0, None, -1.0],
        ],
        schema=["a_div_b", "a_floordiv_b", "b_sign"],
    )

    with pl.SQLContext(df=frame, eager=True) as sql_ctx:
        actual = sql_ctx.execute(
            """
            SELECT
              a / b AS a_div_b,
              a // b AS a_floordiv_b,
              SIGN(b) AS b_sign,
            FROM df
            """
        )

    assert_frame_equal(expected, actual)


def test_equal_not_equal() -> None:
    """Compare the null-unaware and null-aware SQL (in)equality operators."""
    frame = pl.DataFrame({"a": [1, None, 3, 6, 5], "b": [1, None, 3, 4, None]})

    with pl.SQLContext(frame_data=frame) as sql_ctx:
        lazy_out = sql_ctx.execute(
            """
            SELECT
              -- not null-aware
              (a = b)  as "1_eq_unaware",
              (a <> b) as "2_neq_unaware",
              (a != b) as "3_neq_unaware",
              -- null-aware
              (a <=> b) as "4_eq_aware",
              (a IS NOT DISTINCT FROM b) as "5_eq_aware",
              (a IS DISTINCT FROM b) as "6_neq_aware",
            FROM frame_data
            """
        )
        out = lazy_out.collect()

    # null-aware operators must never produce a null result...
    aware_nulls = out.select(cs.contains("_aware").null_count().sum()).row(0)
    assert aware_nulls == (0, 0, 0)

    # ...whereas the null-unaware ones propagate both rows that involve a null
    unaware_nulls = out.select(cs.contains("_unaware").null_count().sum()).row(0)
    assert unaware_nulls == (2, 2, 2)

    expected = {
        "1_eq_unaware": [True, None, True, False, None],
        "2_neq_unaware": [False, None, False, True, None],
        "3_neq_unaware": [False, None, False, True, None],
        "4_eq_aware": [True, True, True, False, False],
        "5_eq_aware": [True, True, True, False, False],
        "6_neq_aware": [False, False, False, True, True],
    }
    assert out.to_dict(as_series=False) == expected


@pytest.mark.parametrize(
    "in_clause",
    [
        "values NOT IN ([0], [3,4], [7,8], [6,6,6])",
        "values IN ([0], [5,6], [1,2], [8,8,8,8])",
        "dt NOT IN ('1950-12-24', '1997-07-05')",
        "dt IN ('2020-10-10', '2077-03-18')",
        "rowid NOT IN (1, 3)",
        "rowid IN (4, 2)",
    ],
)
def test_in_not_in(in_clause: str) -> None:
    """Every parametrized ``IN`` / ``NOT IN`` clause must select the same two rows.

    The clauses cover list, date and integer columns; each one is written so
    that only the rows with ``rowid`` 4 and 2 survive the filter.
    """
    dates = [
        date(2020, 10, 10),
        date(1997, 7, 5),
        date(2077, 3, 18),
        date(1950, 12, 24),
    ]
    frame = pl.DataFrame(
        {
            "rowid": [4, 3, 2, 1],
            "values": [[1, 2], [3, 4], [5, 6], [7, 8]],
            "dt": dates,
        }
    )
    query = f"""
        SELECT "values"
        FROM self
        WHERE {in_clause}
        ORDER BY "rowid" DESC
        """
    filtered = frame.sql(query)

    expected = {"values": [[1, 2], [5, 6]]}
    assert filtered.to_dict(as_series=False) == expected


def test_is_between(foods_ipc_path: Path) -> None:
    """Validate ``BETWEEN`` / ``NOT BETWEEN`` filtering on the ``calories`` column.

    The first query pins the exact rows whose ``calories`` fall in the
    inclusive range 22..30 (ordered by ``calories`` then ``sugars_g``, both
    descending). The second query checks that ``NOT BETWEEN`` returns only
    rows outside that range.
    """
    lf = pl.scan_ipc(foods_ipc_path)

    ctx = pl.SQLContext(foods1=lf, eager=True)
    res = ctx.execute(
        """
        SELECT *
        FROM foods1
        WHERE foods1.calories BETWEEN 22 AND 30
        ORDER BY "calories" DESC, "sugars_g" DESC
        """
    )
    assert res.rows() == [
        ("fruit", 30, 0.0, 5),
        ("vegetables", 30, 0.0, 5),
        ("fruit", 30, 0.0, 3),
        ("vegetables", 25, 0.0, 4),
        ("vegetables", 25, 0.0, 3),
        ("vegetables", 25, 0.0, 2),
        ("vegetables", 22, 0.0, 3),
    ]
    res = ctx.execute(
        """
        SELECT *
        FROM foods1
        WHERE calories NOT BETWEEN 22 AND 30
        ORDER BY "calories" ASC
        """
    )
    calories = res["calories"].to_list()
    # guard against a vacuous pass: `not any(...)` is trivially true when the
    # query returns no rows at all, which would hide an over-eager filter
    assert calories, "NOT BETWEEN query unexpectedly returned no rows"
    assert not any((22 <= cal <= 30) for cal in calories)


def test_logical_not() -> None:
    """Apply SQL ``NOT`` to boolean and integer columns, then to invalid literals."""
    frame = pl.LazyFrame(
        {
            "valid": [True, False, None, False, True],
            "int_code": [1, 0, 2, None, -1],
        },
    )
    query = """
        SELECT
          valid,
          NOT valid AS not_valid,
          int_code,
          NOT int_code AS int_code_zero
        FROM self
        ORDER BY int_code NULLS FIRST
        """
    out = frame.sql(query).collect()

    # expected result, sorted by "int_code" with the null row first:
    # ┌───────┬───────────┬──────────┬───────────────┐
    # │ valid ┆ not_valid ┆ int_code ┆ int_code_zero │
    # │ ---   ┆ ---       ┆ ---      ┆ ---           │
    # │ bool  ┆ bool      ┆ i64      ┆ bool          │
    # ╞═══════╪═══════════╪══════════╪═══════════════╡
    # │ false ┆ true      ┆ null     ┆ null          │
    # │ true  ┆ false     ┆ -1       ┆ false         │
    # │ false ┆ true      ┆ 0        ┆ true          │
    # │ true  ┆ false     ┆ 1        ┆ false         │
    # │ null  ┆ null      ┆ 2        ┆ false         │
    # └───────┴───────────┴──────────┴───────────────┘
    expected = {
        "valid": [False, True, False, True, None],
        "not_valid": [True, False, True, False, None],
        "int_code": [None, -1, 0, 1, 2],
        "int_code_zero": [None, False, True, False, False],
    }
    assert out.to_dict(as_series=False) == expected

    # logical 'NOT' on a dtype that cannot be cast to boolean has to raise
    cast_error = r"cast.* to Boolean not supported"
    bad_operands = ("'foo'", "'2026-12-31'::date")
    for operand in bad_operands:
        with pytest.raises(InvalidOperationError, match=cast_error):
            pl.sql(f"SELECT NOT {operand}", eager=True)


def test_starts_with() -> None:
    """Check the ``^@`` operator against a string literal and another column."""
    frame = pl.LazyFrame(
        {
            "x": ["aaa", "bbb", "a"],
            "y": ["abc", "b", "aa"],
        },
    )

    # right-hand side is a literal prefix
    by_literal = frame.sql("SELECT x ^@ 'a' AS x_starts_with_a FROM self").collect()
    assert by_literal.rows() == [(True,), (False,), (True,)]

    # right-hand side is taken row-wise from column "y"
    by_column = frame.sql("SELECT x ^@ y AS x_starts_with_y FROM self").collect()
    assert by_column.rows() == [(False,), (True,), (False,)]


@pytest.mark.parametrize("match_float", [False, True])
def test_unary_ops_8890(match_float: bool) -> None:
    """Exercise unary ``+``/``-`` in projections and in int/float ``IN`` lists."""
    # the same signed values, spelled either as floats or as integers
    if match_float:
        in_values = "(-3.0, -1.0, +2.0, +4.0)"
    else:
        in_values = "(-3, -1, +2, +4)"

    frame = pl.DataFrame({"a": [-2, -1, 1, 2], "b": ["w", "x", "y", "z"]})
    expected = {
        "a": [-1, 2],
        "b": ["x", "z"],
        "c": [-3, -3],
        "d": [4, 4],
    }

    with pl.SQLContext(df=frame) as sql_ctx:
        lazy_out = sql_ctx.execute(
            f"""
            SELECT *, -(3) as c, (+4) as d
            FROM df WHERE a IN {in_values}
            """
        )
        assert lazy_out.collect().to_dict(as_series=False) == expected
